# _*_ coding: utf-8 _*_
# @Date : 2023/3/22 17:25
# @Author : Paul
# @File : linear_regression.py
# @Description :
import matplotlib.pyplot as plt
import numpy as np
import sys
import json

from core.utils.log_util import LogUtil
from core.beans.regress_result import RegressionResult
from core.utils.string_utils import StringUtils
from core.utils.date_util import DateUtil
from sklearn.model_selection import cross_val_score
from regressions.regression import Regression
from sklearn.linear_model import LinearRegression as LR


class TWLinearRegression(Regression):
    """Linear-regression task backed by ``sklearn.linear_model.LinearRegression``.

    Hyper-parameters are read from ``param["algoParam"]``. On evaluation the
    task computes the RMSE on the test split, plots real vs. predicted
    targets, and stores a ``RegressionResult`` through ``LogUtil``.
    """

    # Textual spellings accepted for boolean hyper-parameters (case-insensitive).
    _TRUE_TEXT = ("true", "1", "yes", "y", "on")
    _FALSE_TEXT = ("false", "0", "no", "n", "off")

    def __init__(self,
                 app_name="linear_regression",
                 data_source_id=None,
                 table_name=None,
                 feature_cols=None,
                 class_col=None,
                 train_size=None,
                 param=None
                 ):
        """
        Initialise the task; every argument is forwarded unchanged to ``Regression``.

        :param app_name: application name, also used to build the image name
        :param data_source_id: forwarded to ``Regression``
        :param table_name: forwarded to ``Regression``
        :param feature_cols: forwarded to ``Regression``
        :param class_col: forwarded to ``Regression``
        :param train_size: forwarded to ``Regression``
        :param param: job parameters; ``initModel`` reads ``param["algoParam"]``
                      and ``evalModel`` reads ``param["id"]``
        """
        super(TWLinearRegression, self).__init__(app_name=app_name,
                                                   data_source_id=data_source_id,
                                                   table_name=table_name,
                                                   feature_cols=feature_cols,
                                                   class_col=class_col,
                                                   train_size=train_size,
                                                   param=param)
        # Set to True here; presumably tells the base workflow to call
        # evalModel() -- TODO confirm against Regression.
        self.IS_MODEL_EVAL = True
        # Image name for the prediction plot. NOTE(review): evalModel() saves to
        # self.regress_pred_image, which is not assigned in this class
        # (presumably by Regression -- confirm); this attribute is kept only so
        # that existing readers of ``tree_pred_image`` keep working.
        self.tree_pred_image = "descion_tree_pred_" + app_name + "_" + DateUtil.getCurrentDateSimple()
        # Maximum number of test rows drawn in the prediction plot.
        self.plot_rows_num = 200

    @classmethod
    def _parse_bool(cls, value, default):
        """Convert a JSON/text flag to ``bool``; return ``default`` when it is missing or blank.

        ``bool("false")`` is ``True`` in Python, so textual flags have to be
        compared against known spellings instead of being cast.

        :raises ValueError: if a non-blank string is not a recognised spelling
        """
        if value is None:
            return default
        if isinstance(value, bool):
            return value
        if isinstance(value, (int, float)):
            return bool(value)
        if StringUtils.isBlack(value):
            return default
        text = str(value).strip().lower()
        if text in cls._TRUE_TEXT:
            return True
        if text in cls._FALSE_TEXT:
            return False
        raise ValueError("cannot interpret %r as a boolean" % (value,))

    def initModel(self):
        """
        Build the (unfitted) estimator from ``param["algoParam"]``.

        Recognised keys: ``fitIntercept`` (default True), ``normalize``
        (default False), ``copyX`` (default True) and ``nJobs`` (default None).
        Missing or blank entries fall back to those defaults.
        """
        import inspect
        import warnings

        algo_param = self.param.get("algoParam") or {}
        fit_intercept = self._parse_bool(algo_param.get("fitIntercept"), True)
        normalize = self._parse_bool(algo_param.get("normalize"), False)
        copy_x = self._parse_bool(algo_param.get("copyX"), True)

        raw_n_jobs = algo_param.get("nJobs")
        n_jobs = None if raw_n_jobs is None or StringUtils.isBlack(raw_n_jobs) else int(raw_n_jobs)

        model_kwargs = {"fit_intercept": fit_intercept,
                        "copy_X": copy_x,
                        "n_jobs": n_jobs}
        # ``normalize`` was removed from LinearRegression in scikit-learn 1.2.
        # Leaving it out when False is equivalent on older versions, so it is
        # only forwarded when requested and still accepted by the estimator.
        if normalize:
            if "normalize" in inspect.signature(LR).parameters:
                model_kwargs["normalize"] = True
            else:
                warnings.warn("the installed scikit-learn no longer supports "
                              "LinearRegression(normalize=True); the option is ignored, "
                              "scale the features during preprocessing instead")

        self.reg = LR(**model_kwargs)

    def buildModel(self, train_data):
        """
        Fit the estimator.

        :param train_data: indexable pair; item 0 is the feature data and
                           item 1 the target data passed to ``fit``
        """
        Xtrain = train_data[0]
        Ytrain = train_data[1]
        self.reg = self.reg.fit(Xtrain, Ytrain)

    def evalModel(self, train_data, test_data):
        """
        Evaluate the fitted model on the test split and persist the result.

        Computes the RMSE into ``self.score_``, maps every feature column to
        its coefficient, saves a real-vs-predicted line plot to
        ``self.regress_pred_image`` and stores a ``RegressionResult``.

        :param train_data: unused here
        :param test_data: indexable pair; item 0 is the feature data and
                          item 1 the target data
        """
        from sklearn.metrics import mean_squared_error as MSE

        Xtest = test_data[0]
        Ytest = test_data[1]
        Ytest_predict = self.reg.predict(Xtest)
        self.score_ = np.sqrt(MSE(Ytest, Ytest_predict))

        # coef_ is (n_features,) for a 1-D target and (n_targets, n_features)
        # for a 2-D one; normalise to 2-D and report the first target's row.
        first_target_coef = np.atleast_2d(self.reg.coef_)[0]
        # float() keeps the values JSON-serialisable whatever the numpy dtype.
        var_importance = {name: float(first_target_coef[i])
                          for i, name in enumerate(self.feature_cols)}

        # Draw at most plot_rows_num rows so the chart stays readable.
        shown_rows = min(len(Xtest), self.plot_rows_num)
        x_axis = np.linspace(0.05, 1, shown_rows)

        fig = plt.figure()
        try:
            plt.plot(x_axis, Ytest[:shown_rows], "green", label="Y-真实值")
            plt.plot(x_axis, Ytest_predict[:shown_rows], "red", label="Y-预测值")
            plt.legend()
            plt.savefig(self.regress_pred_image, dpi=300)
            plt.show()
        finally:
            # Release the figure so repeated runs do not accumulate open figures.
            plt.close(fig)

        # End time of the run.
        end_time = DateUtil.getCurrentDate()
        # NOTE(review): named "cost_second" but computed with DateUtil.diffMin --
        # confirm the unit that is actually stored.
        cost_second = DateUtil.diffMin(self.start_time, end_time)

        # Persist the model result through LogUtil.
        # NOTE(review): the status literal is spelled "sucess"; left unchanged
        # because consumers may match on it -- confirm before correcting.
        algo_result = RegressionResult(self.param["id"],
                                       "linear_regression",
                                       self.param,
                                       self.app_name,
                                       self.info,
                                       self.describe,
                                       self.regress_pred_image,
                                       json.dumps(var_importance).replace("\"", "'"),
                                       self.score_,
                                       "sucess",
                                       self.start_time,
                                       end_time,
                                       cost_second)
        LogUtil.saveRegressionResult(self.meta_data_source, algo_result)


if __name__ == '__main__':
    # Entry point: the single command-line argument is a JSON object describing
    # the job. Example payload:
    #   {"algoParam": {"copyX": "true", "fitIntercept": "true", "nJobs": "", "normalize": "true"},
    #    "appName": "kmeans_1", "classCols": "housePrice", "dataSourceId": "9",
    #    "featureCols": "MedInc,HouseAge,AveRooms,AveBedrms,Population,AveOccup,Latitude,Longitude",
    #    "id": "1679540469249", "preProcessMethodList": [{"preProcessMethod": "deletena"}],
    #    "tableName": "california_housing", "trainDataRatio": "0.7"}
    if len(sys.argv) < 2:
        # Fail with a readable message instead of a bare IndexError.
        sys.exit("usage: %s '<json job parameters>'" % sys.argv[0])

    param = json.loads(sys.argv[1])

    # The target column(s) may arrive as a single name or as a list of names;
    # the task is always given a list.
    class_cols = param["classCols"]
    class_cols_list = class_cols if isinstance(class_cols, list) else [class_cols]

    regressor = TWLinearRegression(app_name=param["appName"],
                                   data_source_id=param["dataSourceId"],
                                   table_name=param["tableName"],
                                   feature_cols=param["featureCols"],
                                   class_col=class_cols_list,
                                   train_size=float(param["trainDataRatio"]),
                                   param=param)

    # A "paramTrain" entry selects paramTrain(); otherwise the regular
    # execute() workflow runs.
    if "paramTrain" in param:
        regressor.paramTrain()
    else:
        regressor.execute()